import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
from pyproj import Proj, Transformer, CRS
from matplotlib.colors import BoundaryNorm
from descartes import PolygonPatch
from wrf import getvar, get_cartopy, to_np
from netCDF4 import Dataset
from cartopy.io.img_tiles import GoogleTiles
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

from plot_marshall_functions import ext2bbox, load_baseimage_any
from plot_marshall_functions import gjson_feature_geometry, gjson2poly
from plot_marshall_functions import refine_dom, read_wrffr
from plot_marshall_functions import pyproj_wrf2wgs, cartopy_defwrfproj, mplext_outer

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Defaults

thisrun = 'ysu-les-large'
plottype = 'area'
nn_idx = 5  # number of panels (output times) to plot; 1..6 for the 3x2 subplot grid
n_idx = nn_idx  # current panel number; the plotting loop counts it down to 1
init_idx = 0  # offset into ddhh: panel k shows output time ddhh[init_idx + k]

geogrid_files = '/glade/scratch/frediani/WRF-Fire/runs/marshall/wps-tim/geo_em.d02.nc'
fpath = os.getcwd()
print("File out path:", fpath)

g = 9.81  # gravitational acceleration [m s-2]

# Date/time strings: 'DD_HH:MM' wrfout times in UTC (ddhh) and the same
# times in local Mountain Time (ddhhlocal). The two lists are parallel.

# yapf:disable
ddhh = [
    '30_18:00', '30_19:00', '30_19:45', '30_22:05', '30_23:00', '31_02:30'
#    '30_19:00', '30_19:45', '30_22:05', '31_02:30'
]

ddhhlocal = [
  '30_11:00',  '30_12:00',  '30_12:45',  '30_15:05',  '30_16:00',  '30_19:30'
#  '30_12:00',  '30_12:45',  '30_15:05',  '30_19:30'
]

# Fail fast on configurations the panel loop cannot handle: with nn_idx < 1
# nothing is plotted (and the save step has no figure/file name), subplot
# indices above 6 do not exist in a 3x2 grid, and every panel needs a time.
if not 1 <= nn_idx <= 6:
    raise ValueError(f'nn_idx must be in 1..6 (3x2 panel grid), got {nn_idx}')
if init_idx < 0 or init_idx + nn_idx >= len(ddhh):
    raise ValueError(f'init_idx + nn_idx must index ddhh (len {len(ddhh)}), '
                     f'got init_idx={init_idx}, nn_idx={nn_idx}')
if len(ddhhlocal) != len(ddhh):
    raise ValueError('ddhh and ddhhlocal must list the same number of times')

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Read WPS Geographic Data

f0 = geogrid_files

# Cartopy CRS for the WRF domain; its proj4 string carries the datum, so it
# is the CRS to pass as `transform=` for data on the WRF grid.
wpwrfcrs = cartopy_defwrfproj(f0)
#pltcrs = ccrs.epsg(3857)

# PyProj view of the same grid: CRSs, extents, grid/map coordinates, spacing.
(wrfcrs, mapcrs, wrfext, mapext, latlonext,
 wrfxy, mapxy, dxy) = pyproj_wrf2wgs(wf=f0)
wrfdx = dxy[0]

# --------------------------------------------------------------------
# Observed fire perimeter (GeoJSON features)
perim1 = gjson_feature_geometry(filename='marshall.json')

# --------------------------------------------------------------------
# Base image elements

# Outer map extent padded element-wise by half a grid spacing.
# NOTE(review): the (-dx, +dy, -dx, +dy) pattern assumes the element order
# mplext_outer returns (and dx == dy); confirm against plot_marshall_functions.
outer_ext = mplext_outer(mapext)
half_cell = np.array([-dxy[0], dxy[1], -dxy[0], dxy[1]]) / 2.
zext = (outer_ext + half_cell).tolist()

# Projector from lon/lat to projected metres in the map CRS.
map2meters = Proj(mapcrs, preserve_units=False)

# Perimeter polygons in map metres, ready for PolygonPatch.
polygon = [gjson2poly(ft=feature, prjtrans=map2meters) for feature in perim1]

# Background image bounding box: zext offset by 3.2 km per element (+, -, +, -).
# NOTE(review): `base` is only referenced by commented-out plotting code in
# this file, so this download currently has no visible effect on the figure.
base_offset = np.array([3200., -3200.0, 3200.0, -3200.0])
baseext = (np.array(zext) + base_offset).tolist()
base = load_baseimage_any('USGSTopo', bbox=ext2bbox(*baseext))


# --------------------------------------------------------------------
# --------------------------------------------------------------------
# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Loop-invariant setup for the panel loop

# Web-Mercator (EPSG:3857) window shown in every panel, in metres.
MAP_XLIM = (-11717000.139409136, -11701400.535952998)
MAP_YLIM = (4854271.4235172225, 4866000.624060913)
NTICKS = 7  # tick marks per axis
DEG = '\N{DEGREE SIGN}'

# Ignition location (lon, lat), drawn as a star on every panel.
IGN_LONLAT = (-105.230189, 39.956029)

# Quiver settings for the 10-m wind vectors.
skipx = 15
skipy = 8
scale = 70  # bigger makes arrows smaller
barbw = 0.03


class ShadedReliefESRI(GoogleTiles):
    """ESRI ArcGIS Online tile source used as the panel background.

    NOTE: despite the class name, the URL points at the ``World_Street_Map``
    service, not a shaded-relief layer.
    """

    def _image_url(self, tile):
        x, y, z = tile
        return ('https://server.arcgisonline.com/ArcGIS/rest/services/'
                'World_Street_Map/MapServer/tile/{z}/{y}/{x}.jpg').format(
                    z=z, y=y, x=x)


def _froude_transition(xlat, xlon, zht, u, v, th, hgt,
                       grav=9.81, first_row=5, nlev=36):
    """Locate, row by row, the first column where the bulk Froude number is < 1.

    For every row index from ``first_row`` to the last, and every column:
      * the boundary-layer top is the level with the largest vertical
        theta gradient among the lowest ``nlev`` levels (index ``kmax``);
      * U is the mean wind speed and theta_bl the mean theta below ``kmax``;
      * theta_ft is theta two levels above ``kmax``;
      * Fr = U / sqrt(g' h), g' = grav * (theta_ft - theta_bl) / theta_bl,
        h = height of level ``kmax`` above the terrain ``hgt``.

    Returns ``(mean_lats, trans_lons)``, one entry per processed row: the mean
    of ``xlat`` along the row, and the ``xlon`` of the lowest-index column with
    Fr < 1 -- or NaN when no column in that row is sub-critical (matplotlib
    skips NaN points when plotting).

    NOTE(review): assumes 3-D fields are (level, row, column) and 2-D fields
    (row, column), as returned by wrf.getvar for a single time; the reason the
    first 5 rows are skipped is not stated in this script.
    """
    rows = np.arange(first_row, np.shape(xlat)[0])
    mean_lats = np.empty(len(rows))
    trans_lons = np.empty(len(rows))
    for k, row in enumerate(rows):
        zv = zht[:, row, :]
        uv = u[:, row, :]
        vv = v[:, row, :]
        thv = th[:, row, :]
        hgtv = hgt[row, :]
        ncols = thv.shape[1]
        fr = np.empty(ncols)
        for col in range(ncols):
            dth_dz = ((thv[1:nlev, col] - thv[0:nlev - 1, col])
                      / (zv[1:nlev, col] - zv[0:nlev - 1, col]))
            kmax = np.argmax(dth_dz)
            pblh = zv[kmax, col] - hgtv[col]
            mean_wspd = np.mean(np.sqrt(uv[0:kmax, col]**2. + vv[0:kmax, col]**2.))
            mean_ptemp = np.mean(thv[0:kmax, col])
            ft_ptemp = thv[kmax + 2, col]
            gprime = grav * (ft_ptemp - mean_ptemp) / mean_ptemp
            fr[col] = mean_wspd / np.sqrt(gprime * pblh)
        mean_lats[k] = np.mean(np.asarray(xlat[row, :ncols], dtype=float))
        subcritical = np.flatnonzero(fr < 1.)
        # Rows that never become sub-critical get NaN instead of crashing.
        trans_lons[k] = xlon[row, subcritical[0]] if subcritical.size else np.nan
    return mean_lats, trans_lons


plt.ioff()
fig = plt.figure(figsize=(10, 12))

# EPSG:3857 -> EPSG:4326 for tick labels. Without always_xy=True the output
# follows the EPSG:4326 authority axis order, i.e. (lat, lon). Web-Mercator
# y maps to latitude non-linearly, so each tick is transformed individually.
transformer = Transformer.from_crs(CRS.from_epsg(3857), CRS.from_epsg(4326))
XTICKS = np.linspace(MAP_XLIM[0], MAP_XLIM[1], NTICKS)
YTICKS = np.linspace(MAP_YLIM[1], MAP_YLIM[0], NTICKS)
_tick_lat, _tick_lon = transformer.transform(XTICKS, YTICKS)
XTICKLABELS = [f'{abs(float(t)):0.2f}{DEG} W' for t in _tick_lon]
YTICKLABELS = [f'{abs(float(t)):0.2f}{DEG} N' for t in _tick_lat]

# Grid rotation (COSALPHA/SINALPHA) and lat/lon from wrfinput, used to turn
# the 10-m winds into earth-relative vectors. Read once, file closed after.
wrfinput_files = '/glade/scratch/tjuliano/fire/sagehen/WRF-Fire-LEAPHI/WRF/test/marshall/2021123012_maria/wrfinput_d02'
with Dataset(wrfinput_files, "r") as dataset3:
    cosalpha = dataset3.variables['COSALPHA'][-1, :, :]
    sinalpha = dataset3.variables['SINALPHA'][-1, :, :]
    lat = getvar(dataset3, "lat", meta=False)
    lon = getvar(dataset3, "lon", meta=False)

fnamestr = 'wrfout_d02_2021-12-{}:00'
run_root = '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/'

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Panel loop: subplot n_idx shows time ddhh[init_idx + n_idx], counting
# n_idx down from nn_idx to 1.
for ii in range(nn_idx):
    for jj in range(1):
        case = 'big2_maria_run9' if jj == 1 else 'big2_maria_run9_add_spotting'
        fwrf = run_root + case + '/fire'
        fwrf2 = run_root + case + '/spotting'
        fwrf3 = run_root + case + '/met'
        print(f'thisrun: {thisrun} plottype: {plottype}\ninit_idx: {init_idx} n_idx: {n_idx}')

        # Output file name; the save step uses the one from the last iteration.
        fn0 = thisrun.replace('/', '_')
        fn2 = f'{init_idx:02d}-{n_idx:02d}'
        fn3 = 'dt' + f'{(init_idx/4):02.2f}' + '-' + f'{(init_idx+n_idx)/4:02.2f}'
        fname = '_'.join([fn0, plottype, fn2, fn3]).replace('.', '_') + '_panel.png'
        print("File out name:", fname)

        # ------------------------------------------------------------
        # Read WRF data for this panel's output time
        fw1 = fnamestr.format(ddhh[init_idx + n_idx])

        # Grid lat/lon (spotting output); needed by the Fr analysis for
        # every plottype, not only 'area'.
        xlat = read_wrffr(os.path.join(fwrf2, fw1), var='XLAT')
        xlon = read_wrffr(os.path.join(fwrf2, fw1), var='XLONG')

        if plottype == 'area':
            print("Area:", fw1)
            print("Refined Area:", fw1)
            # coordinates don't exist in the netcdf file for the last 3 points,
            # then remove gridpoints beyond domain edges
            refarea = read_wrffr(os.path.join(fwrf, fw1), var='FIRE_AREA')[0:-3, 0:-3]
            refarea = refarea[0:-4, 0:-4]

            # Firebrand landing locations: cells with a positive landed count.
            # Selecting on the raw count avoids the 0/0 -> NaN case that would
            # otherwise keep every cell when nothing has landed yet.
            landed = read_wrffr(os.path.join(fwrf2, fw1), var='FS_COUNT_LANDED_HIST')
            has_landed = np.ravel(landed) > 0
            xlat2 = np.ravel(xlat)[has_landed]
            xlon2 = np.ravel(xlon)[has_landed]

        with Dataset(os.path.join(fwrf3, fw1)) as nc:
            u = getvar(nc, 'ua').values
            v = getvar(nc, 'va').values
            u10 = getvar(nc, 'U10').values
            v10 = getvar(nc, 'V10').values
            th = getvar(nc, 'T').values + 300.  # perturbation theta + 300 K
            ph = getvar(nc, 'PH').values
            phb = getvar(nc, 'PHB').values
            hgt = getvar(nc, 'HGT').values
        # Height of mass levels: staggered geopotential averaged to centres / g.
        zht = ((ph[1:] + ph[0:-1]) / 2. + (phb[1:] + phb[0:-1]) / 2.) / g

        # ------------------------------------------------------------
        # Panel axes and background
        # NOTE(review): the second-column index 2*n_idx only fits the 3x2
        # grid for n_idx <= 3.
        panel = n_idx if jj == 0 else 2 * n_idx
        tiles = ShadedReliefESRI()
        ax = fig.add_subplot(3, 2, panel, projection=tiles.crs)
        textdict = dict(c='k', va='center', transform=ax.transAxes, fontsize=12)

        # Tile zoom level 12: larger domains need a smaller level; higher
        # levels fetch more tiles and take longer to render.
        ax.add_image(tiles, 12, interpolation='spline36')
        ax.set_xlim(MAP_XLIM)
        ax.set_ylim(MAP_YLIM)

        # Observed perimeter (plus an empty line as its legend entry)
        for poly in polygon:
            ax.add_patch(PolygonPatch(poly, fc='None', ec='red', zorder=2, lw=3.))
        ax.plot([], [], color='red', label='OBS Perim', zorder=2, linewidth=3)

        # ------------------------------------------------------------
        # Modelled fire perimeter and firebrands
        if plottype == 'area':
            fxlon, fxlat = refine_dom(wrfxy[0], wrfxy[1], 0.25)
            ax.contour(fxlon, fxlat, refarea, [0.06], colors='k', zorder=3,
                       linewidths=3., transform=wpwrfcrs)
            ax.plot([], [], color='k', label='WRF Perim', zorder=3, linewidth=3)
            ax.scatter(xlon2, xlat2, s=50, marker='o', fc='orange', ec='gray',
                       transform=ccrs.PlateCarree(), zorder=5, label='WRF Firebrands')

        # ------------------------------------------------------------
        # 10-m wind vectors: rotate grid-relative winds to earth-relative,
        # then rescale so PlateCarree arrows keep the true speed.
        uearth1000 = u10 * cosalpha - v10 * sinalpha
        vearth1000 = v10 * cosalpha + u10 * sinalpha

        u_src_crs1000 = to_np(uearth1000) / np.cos(to_np(lat) / 180 * np.pi)
        v_src_crs1000 = to_np(vearth1000)
        magnitude1000 = np.sqrt(to_np(uearth1000)**2 + to_np(vearth1000)**2)
        magn_src_crs1000 = np.sqrt(u_src_crs1000**2 + v_src_crs1000**2)

        sub = (slice(None, None, skipy), slice(None, None, skipx))
        q = ax.quiver(to_np(lon[sub]), to_np(lat[sub]),
                      u_src_crs1000[sub] * magnitude1000[sub] / magn_src_crs1000[sub],
                      v_src_crs1000[sub] * magnitude1000[sub] / magn_src_crs1000[sub],
                      transform=ccrs.PlateCarree(), linewidth=1., zorder=4,
                      units='inches', scale=scale, width=barbw,
                      facecolor='limegreen', edgecolor='black')
        if n_idx == nn_idx:
            ax.quiverkey(q, 1.65, 0.85, 15, '15 m/s', labelpos='E',
                         fontproperties={'weight': 'bold', 'size': 20}, zorder=6)

        # ------------------------------------------------------------
        # Froude-number transition (every other row, last one excluded)
        mean_lata, fr_transa = _froude_transition(xlat, xlon, zht, u, v, th, hgt, grav=g)
        ax.scatter(fr_transa[0:-1:2], mean_lata[0:-1:2], s=50, marker='d', fc='blue',
                   ec='gray', transform=ccrs.PlateCarree(), zorder=5,
                   label='$Fr$ Transition')

        ax.scatter(*IGN_LONLAT, s=400, marker='*', fc='magenta', ec='gray',
                   transform=ccrs.PlateCarree(), zorder=5, label='Ign Location')

        # ------------------------------------------------------------
        # Title and tick labels (x labels on the bottom row, y on the left column)
        ll0 = fw1.replace('wrfout_d02_', '')[:-3]
        ax.set_title(f'{ll0} UTC', **{**textdict, **dict(fontsize=13)})
        if n_idx in (4, 5):
            ax.set_xticks(XTICKS)
            ax.set_xticklabels(XTICKLABELS)
        if n_idx in (1, 3, 5):
            ax.set_yticks(YTICKS)
            ax.set_yticklabels(YTICKLABELS)
        ax.tick_params(axis="both", direction='in')
        ax.xaxis.set_tick_params(labelsize=11, rotation=30)
        ax.yaxis.set_tick_params(labelsize=11)

    # One shared legend, built from the first panel's labelled artists.
    if ii == 0:
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='center right', bbox_to_anchor=(0.92, 0.2),
                   bbox_transform=fig.transFigure, ncol=1, fontsize=18)

    n_idx = n_idx - 1

# Save the multi-panel figure. `case` and `fname` hold the values from the
# last loop iteration. Build the path once so the printed path is exactly
# the file that gets written.
outfile = os.path.join(fpath, case + '_' + fname)
fig.savefig(outfile, dpi=600, bbox_inches='tight')
print(outfile)

#plt.close()
